In [21]:
import numpy as np
import pandas as pd
import json
import sys
import os
import matplotlib
matplotlib.use('Agg') 
import matplotlib.pyplot as plt
import seaborn as sns
import pdb

from util import utils as data_utils

%pylab inline
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'Blues'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

json_file = './cifar_results/noise_45/var_bootstrap_lr_001/checkpoint_100.json'
FDIR = os.path.dirname(json_file)
NUM_CLASSIFY = 10
Populating the interactive namespace from numpy and matplotlib
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [ ]:
 
In [22]:
# Plot gradients norms for the entire learning process
grads_json_filename = os.path.join(FDIR, 'model_grads.json')
grads = [[], [], []]
grads_key = ['max_grad_w1_16', 'max_grad_w1_32', 'max_grad_w1_64']
if os.path.exists(grads_json_filename):
    with open(grads_json_filename, 'r') as fp:
        data = json.load(fp)
        for i, k in enumerate(grads_key):
            if data[0].get(k, None) is None:
                continue
            for batch_grads in data:
                grads[i].append(batch_grads[k])

def plot_grads(grads, title, x_label, y_label, figsize=(10, 8)):
    plt.figure(figsize=figsize)
    # plt.subplot(2, 1, 1)
    plt.plot(grads)
    plt.title(title)
    plt.ylabel(y_label)
    plt.xlabel(x_label)
    
for i, g in enumerate(grads):
    if len(g) > 0:
        plot_grads(g, grads_key[i], 'iterations', grads_key[i])
        # pass
In [23]:
with open(json_file, 'r') as fp:
    data = json.load(fp)
# Loss history might not be of equal length.
train_loss_hist = data['train_loss_history']
val_loss_hist = data['val_loss_history']

# pdb.set_trace()
def plot_loss_hist(loss_hist, title,):
    plt.figure(figsize=(5,4))
    plt.subplot(1, 1, 1)
    plt.plot(loss_hist)
    plt.title(title)  # Train Loss
    plt.ylabel('loss')
    plt.xlabel('time')
    plt.show()
    
plot_loss_hist(train_loss_hist, 'Train Loss')
plot_loss_hist(val_loss_hist, 'Val loss')

if data.get('crit1_loss_history', None) is not None:
    plot_loss_hist(data['crit1_loss_history'], 'Target criterion loss')

if data.get('crit2_loss_history', None) is not None and \
    len(data['crit2_loss_history']) > 0:
    plot_loss_hist(data['crit2_loss_history'], 'Pred criterion loss')

if data.get('pred_loss_history', None) is not None and \
    len(data['pred_loss_history']) > 0:
    plot_loss_hist(data['pred_loss_history'], 'Total Pred loss (beta*t + (1-beta)*p)')    

if data.get('beta_loss_history', None) is not None and \
    len(data['beta_loss_history']) > 0:
    plot_loss_hist(data['beta_loss_history'], 'Beta loss')
In [24]:
if data.get('KL_loss_history', None) is not None:
    # Loss history might not be of equal length.
    KL_loss_hist = data['KL_loss_history']

    plt.figure(figsize=(10,8))
    plt.plot(KL_loss_hist)
    plt.title('KL loss')
    plt.ylabel('loss')
    plt.xlabel('time')
    plt.show()
In [25]:
def get_conf(json_file, num_classes=26, json_key='conf'):
    with open(json_file, 'r') as fp:
        data = json.load(fp)
        conf = data.get(json_key, None)
    if conf is None:
        return
    # c1 = conf.split('\n')[1].split("]")[0].split("[ ")[1].split(" ")
    c1 = conf.split('\n')
    # print(c1)
    conf_mat, row_idx = np.zeros((num_classes, num_classes)), 0
    for i in c1:
        #pdb.set_trace()
        is_conf_row = False
        if ']' in i and '[[' in i:
            val = i.split(']')[0].split('[[')[1].split(' ')
            is_conf_row = True
        elif ']' in i and '[' in i:
            val = i.split(']')[0].split('[')[1].split(' ')
            is_conf_row = True
        if is_conf_row:
            col_idx = 0
            for v in val:
                if not len(v):
                    continue
                try:
                    conf_mat[row_idx, col_idx] = int(v)
                    col_idx = col_idx + 1
                except:
                    continue
            row_idx = row_idx + 1
    
    assert(row_idx == num_classes)
    conf_mat = conf_mat.astype(int)
    fdir = os.path.dirname(json_file)
    json_name = os.path.basename(json_file)[:-5]
    conf_file_name = fdir + '/' + 'conf_' + json_name + '.txt'
    np.savetxt(conf_file_name, conf_mat, fmt='%d', delimiter=', ')
    return conf_mat


def plot_conf(norm_conf):
  # Plot using seaborn
  # (this is style I used for ResNet matrix)
  plt.figure(figsize=(10,6))
  df_cm = pd.DataFrame(norm_conf)
  sns.heatmap(df_cm, annot=True, cmap="Blues")
  plt.show()
In [26]:
def get_sorted_checkpoints(fdir):
    # Checkpoint files are named as 'checkpoint_%d.json'
    checkpoint_map = {}
    for f in os.listdir(fdir):
        if f.endswith('json') and f.startswith('checkpoint'):
            checkpoint_num = int(f.split('checkpoint_')[-1].split('.')[0])
            checkpoint_map[checkpoint_num] = f
    sorted_checkpoints = []
    for k in sorted(checkpoint_map.keys()):
        v = checkpoint_map[k]
        sorted_checkpoints.append(v)
    return sorted_checkpoints
In [27]:
def best_f_scores(fdir, num_classes=5): 
    best_checkpoints = [None, None, None]
    best_3_fscores = [0, 0, 0]
    best_confs = [np.array(()), np.array(()), np.array(())]
    f1_weight_list = [1.0] * num_classes
    f1_weights = np.array(f1_weight_list)
    sorted_checkpoint_files = get_sorted_checkpoints(fdir)
    for f in sorted_checkpoint_files:
        json_file = fdir + '/' + f
        conf = get_conf(json_file, num_classes, json_key='val_conf')
        norm_conf = data_utils.normalize_conf(conf)
        f1 = data_utils.get_f1_score(conf, f1_weights)
        kappa = data_utils.computeKappa(conf)
        wt_f1 = data_utils.computeWeightedF1(conf)
        print('file: {}, f1: {:.3f}, kappa: {:.3f}, weighted-F1: {:.3f}'.format(
                f, f1, kappa, wt_f1))
        plot_conf(norm_conf)
        max_idx = -1
        for i in range(len(best_3_fscores)):
            if best_3_fscores[i] > f1:
                break
            max_idx = i
        for j in range(max_idx):
            best_3_fscores[j] = best_3_fscores[j+1]
            best_confs[j] = best_confs[j+1]
            best_checkpoints[j] = best_checkpoints[j+1]

        best_3_fscores[max_idx] = f1
        best_confs[max_idx] = conf
        best_checkpoints[max_idx] = f

    return best_3_fscores, best_confs, best_checkpoints
In [28]:
def plot_train_conf(fdir, num_classes=5):
    sorted_checkpoint_files = get_sorted_checkpoints(fdir)
    if len(sorted_checkpoint_files) > 0:
        last_checkpoint = sorted_checkpoint_files[-1]
        json_file = fdir + '/' + last_checkpoint
        conf = get_conf(json_file, num_classes=num_classes, json_key='train_conf')
        print(conf)
        norm_conf = data_utils.normalize_conf(conf)
        f1_weight_list = [1.0] * num_classes
        f1_weights = np.array(f1_weight_list)
        f1 = data_utils.get_f1_score(conf, f1_weights)
        kappa = data_utils.computeKappa(conf)
        wt_f1 = data_utils.computeWeightedF1(conf)
        print('file: {}, f1: {:.3f}, kappa: {:.3f}, weighted-F1: {:.3f}'.format(
            f, f1, kappa, wt_f1))
        plot_conf(norm_conf)

plot_train_conf(FDIR, num_classes=NUM_CLASSIFY)
[[3200   20  110   58   32   10 1446   10   60   77]
 [  30 3429    1   12    2    5    8    4   16 1452]
 [ 135    1 2986  248  104 1314   91   71    6    9]
 [  44   24   88 2898   70  353   80   36 1424   27]
 [  21    2   94   93 3191   72   80 1423    3    9]
 [  49    3 1434  395  144 2826   68  106   10    0]
 [  45    4   87  110 1341   50 3281   29    1    1]
 [  19 1459   16   60  110   84    7 3218    7   32]
 [  73   14   29 1249   28  146   45   18 3322   32]
 [1415  138   25   25   10    7   12   26   60 3351]]
file: <built-in method f of mtrand.RandomState object at 0x7f40522f3230>, f1: 0.634, kappa: 0.421, weighted-F1: 0.634
In [29]:
best_f_scores(FDIR, num_classes=NUM_CLASSIFY)
file: checkpoint_1.json, f1: 0.089, kappa: -0.023, weighted-F1: 0.089
file: checkpoint_2.json, f1: 0.235, kappa: -0.061, weighted-F1: 0.235
file: checkpoint_3.json, f1: 0.327, kappa: 0.038, weighted-F1: 0.327
file: checkpoint_4.json, f1: 0.506, kappa: 0.320, weighted-F1: 0.506
file: checkpoint_5.json, f1: 0.565, kappa: 0.581, weighted-F1: 0.565
file: checkpoint_6.json, f1: 0.624, kappa: 0.619, weighted-F1: 0.624
file: checkpoint_7.json, f1: 0.668, kappa: 0.677, weighted-F1: 0.668
file: checkpoint_8.json, f1: 0.707, kappa: 0.706, weighted-F1: 0.707
file: checkpoint_9.json, f1: 0.756, kappa: 0.763, weighted-F1: 0.756
file: checkpoint_10.json, f1: 0.746, kappa: 0.736, weighted-F1: 0.746
file: checkpoint_11.json, f1: 0.775, kappa: 0.753, weighted-F1: 0.775
file: checkpoint_12.json, f1: 0.776, kappa: 0.785, weighted-F1: 0.776
file: checkpoint_13.json, f1: 0.809, kappa: 0.806, weighted-F1: 0.809
file: checkpoint_14.json, f1: 0.791, kappa: 0.796, weighted-F1: 0.791
file: checkpoint_15.json, f1: 0.816, kappa: 0.800, weighted-F1: 0.816
file: checkpoint_16.json, f1: 0.821, kappa: 0.823, weighted-F1: 0.821
file: checkpoint_17.json, f1: 0.813, kappa: 0.823, weighted-F1: 0.813
file: checkpoint_18.json, f1: 0.823, kappa: 0.822, weighted-F1: 0.823
file: checkpoint_19.json, f1: 0.844, kappa: 0.834, weighted-F1: 0.844
file: checkpoint_20.json, f1: 0.833, kappa: 0.833, weighted-F1: 0.833
file: checkpoint_21.json, f1: 0.844, kappa: 0.840, weighted-F1: 0.844
file: checkpoint_22.json, f1: 0.840, kappa: 0.842, weighted-F1: 0.840
file: checkpoint_23.json, f1: 0.841, kappa: 0.841, weighted-F1: 0.841
file: checkpoint_24.json, f1: 0.842, kappa: 0.842, weighted-F1: 0.842
file: checkpoint_25.json, f1: 0.843, kappa: 0.843, weighted-F1: 0.843
file: checkpoint_26.json, f1: 0.842, kappa: 0.843, weighted-F1: 0.842
file: checkpoint_27.json, f1: 0.845, kappa: 0.843, weighted-F1: 0.845
file: checkpoint_28.json, f1: 0.843, kappa: 0.844, weighted-F1: 0.843
file: checkpoint_29.json, f1: 0.847, kappa: 0.847, weighted-F1: 0.847
file: checkpoint_30.json, f1: 0.844, kappa: 0.844, weighted-F1: 0.844
file: checkpoint_31.json, f1: 0.844, kappa: 0.843, weighted-F1: 0.844
file: checkpoint_32.json, f1: 0.848, kappa: 0.848, weighted-F1: 0.848
file: checkpoint_33.json, f1: 0.848, kappa: 0.845, weighted-F1: 0.848
file: checkpoint_34.json, f1: 0.846, kappa: 0.850, weighted-F1: 0.846
file: checkpoint_35.json, f1: 0.844, kappa: 0.844, weighted-F1: 0.844
file: checkpoint_36.json, f1: 0.845, kappa: 0.844, weighted-F1: 0.845
file: checkpoint_37.json, f1: 0.849, kappa: 0.847, weighted-F1: 0.849
file: checkpoint_38.json, f1: 0.851, kappa: 0.848, weighted-F1: 0.851
file: checkpoint_39.json, f1: 0.847, kappa: 0.846, weighted-F1: 0.847
file: checkpoint_40.json, f1: 0.848, kappa: 0.848, weighted-F1: 0.848
file: checkpoint_41.json, f1: 0.851, kappa: 0.849, weighted-F1: 0.851
file: checkpoint_42.json, f1: 0.849, kappa: 0.846, weighted-F1: 0.849
file: checkpoint_43.json, f1: 0.849, kappa: 0.847, weighted-F1: 0.849
file: checkpoint_44.json, f1: 0.850, kappa: 0.849, weighted-F1: 0.850
file: checkpoint_45.json, f1: 0.853, kappa: 0.852, weighted-F1: 0.853
file: checkpoint_46.json, f1: 0.853, kappa: 0.852, weighted-F1: 0.853
file: checkpoint_47.json, f1: 0.850, kappa: 0.847, weighted-F1: 0.850
file: checkpoint_48.json, f1: 0.848, kappa: 0.849, weighted-F1: 0.848
file: checkpoint_49.json, f1: 0.850, kappa: 0.847, weighted-F1: 0.850
file: checkpoint_50.json, f1: 0.849, kappa: 0.846, weighted-F1: 0.849
file: checkpoint_51.json, f1: 0.855, kappa: 0.855, weighted-F1: 0.855
file: checkpoint_52.json, f1: 0.851, kappa: 0.852, weighted-F1: 0.851
file: checkpoint_53.json, f1: 0.851, kappa: 0.849, weighted-F1: 0.851
file: checkpoint_54.json, f1: 0.851, kappa: 0.848, weighted-F1: 0.851
file: checkpoint_55.json, f1: 0.852, kappa: 0.853, weighted-F1: 0.852
file: checkpoint_56.json, f1: 0.854, kappa: 0.851, weighted-F1: 0.854
file: checkpoint_57.json, f1: 0.851, kappa: 0.851, weighted-F1: 0.851
file: checkpoint_58.json, f1: 0.850, kappa: 0.852, weighted-F1: 0.850
file: checkpoint_59.json, f1: 0.850, kappa: 0.847, weighted-F1: 0.850
file: checkpoint_60.json, f1: 0.852, kappa: 0.854, weighted-F1: 0.852
file: checkpoint_61.json, f1: 0.852, kappa: 0.853, weighted-F1: 0.852
file: checkpoint_62.json, f1: 0.850, kappa: 0.852, weighted-F1: 0.850
file: checkpoint_63.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_64.json, f1: 0.851, kappa: 0.848, weighted-F1: 0.851
file: checkpoint_65.json, f1: 0.854, kappa: 0.853, weighted-F1: 0.854
file: checkpoint_66.json, f1: 0.853, kappa: 0.854, weighted-F1: 0.853
file: checkpoint_67.json, f1: 0.852, kappa: 0.854, weighted-F1: 0.852
file: checkpoint_68.json, f1: 0.855, kappa: 0.856, weighted-F1: 0.855
file: checkpoint_69.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_70.json, f1: 0.853, kappa: 0.854, weighted-F1: 0.853
file: checkpoint_71.json, f1: 0.854, kappa: 0.853, weighted-F1: 0.854
file: checkpoint_72.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_73.json, f1: 0.853, kappa: 0.851, weighted-F1: 0.853
file: checkpoint_74.json, f1: 0.856, kappa: 0.854, weighted-F1: 0.856
file: checkpoint_75.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_76.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_77.json, f1: 0.857, kappa: 0.855, weighted-F1: 0.857
file: checkpoint_78.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_79.json, f1: 0.852, kappa: 0.853, weighted-F1: 0.852
file: checkpoint_80.json, f1: 0.855, kappa: 0.854, weighted-F1: 0.855
file: checkpoint_81.json, f1: 0.855, kappa: 0.853, weighted-F1: 0.855
file: checkpoint_82.json, f1: 0.854, kappa: 0.855, weighted-F1: 0.854
file: checkpoint_83.json, f1: 0.856, kappa: 0.854, weighted-F1: 0.856
file: checkpoint_84.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_85.json, f1: 0.855, kappa: 0.852, weighted-F1: 0.855
file: checkpoint_86.json, f1: 0.853, kappa: 0.852, weighted-F1: 0.853
file: checkpoint_87.json, f1: 0.854, kappa: 0.854, weighted-F1: 0.854
file: checkpoint_88.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_89.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_90.json, f1: 0.855, kappa: 0.854, weighted-F1: 0.855
file: checkpoint_91.json, f1: 0.855, kappa: 0.851, weighted-F1: 0.855
file: checkpoint_92.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_93.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_94.json, f1: 0.855, kappa: 0.854, weighted-F1: 0.855
file: checkpoint_95.json, f1: 0.855, kappa: 0.855, weighted-F1: 0.855
file: checkpoint_96.json, f1: 0.856, kappa: 0.853, weighted-F1: 0.856
file: checkpoint_97.json, f1: 0.854, kappa: 0.855, weighted-F1: 0.854
file: checkpoint_98.json, f1: 0.856, kappa: 0.853, weighted-F1: 0.856
file: checkpoint_99.json, f1: 0.854, kappa: 0.854, weighted-F1: 0.854
file: checkpoint_100.json, f1: 0.855, kappa: 0.857, weighted-F1: 0.855
Out[29]:
([0.85508576429972893, 0.85604497311060146, 0.85448674977052219],
 [array([[875,   4,  22,   8,   8,   0,   4,   9,  47,  23],
         [ 11, 932,   2,   5,   0,   0,   4,   3,  17,  26],
         [ 47,   0, 821,  18,  38,  24,  41,   6,   5,   0],
         [ 21,   1,  54, 723,  35,  79,  54,  18,  12,   3],
         [  8,   1,  33,  16, 875,   9,  43,  13,   1,   1],
         [  5,   1,  98, 126,  35, 679,  23,  32,   1,   0],
         [ 11,   1,  21,  10,  11,   2, 940,   1,   2,   1],
         [  6,   0,  15,  18,  68,  13,   3, 877,   0,   0],
         [ 14,   9,   3,  24,   4,   1,   1,   0, 938,   6],
         [ 15,  49,   2,   6,   4,   1,   3,   2,  16, 902]]),
  array([[872,   4,  23,  11,   8,   0,   5,   7,  47,  23],
         [ 11, 933,   2,   5,   0,   0,   3,   3,  17,  26],
         [ 45,   0, 821,  20,  36,  25,  40,   6,   6,   1],
         [ 20,   1,  44, 738,  31,  85,  50,  17,  12,   2],
         [  8,   1,  35,  19, 864,  12,  47,  12,   1,   1],
         [  5,   1,  83, 133,  32, 703,  19,  24,   0,   0],
         [ 11,   1,  24,  11,  10,   2, 937,   1,   2,   1],
         [  6,   0,  15,  23,  65,  14,   2, 875,   0,   0],
         [ 15,   9,   3,  26,   4,   1,   1,   0, 937,   4],
         [ 15,  57,   3,   7,   4,   1,   3,   1,  22, 887]]),
  array([[872,   4,  23,  10,   8,   0,   5,   7,  48,  23],
         [ 10, 933,   2,   6,   0,   0,   3,   3,  17,  26],
         [ 47,   0, 822,  20,  34,  23,  42,   6,   6,   0],
         [ 21,   1,  46, 739,  31,  79,  51,  17,  12,   3],
         [  8,   1,  39,  19, 861,  11,  47,  12,   1,   1],
         [  5,   1,  96, 133,  32, 681,  21,  30,   1,   0],
         [ 11,   1,  22,  11,  10,   2, 939,   1,   2,   1],
         [  6,   0,  15,  22,  65,  13,   3, 876,   0,   0],
         [ 13,   9,   2,  24,   4,   1,   1,   0, 941,   5],
         [ 15,  56,   2,   6,   4,   1,   3,   2,  21, 890]])],
 ['checkpoint_100.json', 'checkpoint_96.json', 'checkpoint_99.json'])